from numpy import pi, floor, sqrt, sin, cos
from math import acos, asin
import numpy as np
from scipy.interpolate import pchip_interpolate, CubicSpline
import matplotlib.pyplot as plt


def trans(point=np.zeros(3, dtype=float), angle=np.zeros(3, dtype=float), tran=np.zeros(3, dtype=float), mode=True):
	'''
	mode=True :p是点在{a}中的位置 p不动 将{a}坐标系变换到{b}坐标系后 p在{b}中的位置
	mode=False :p是点在{a}中的位置 返回p旋转平移后在{a}中的新位置
	:param point: p在{a}中的坐标
	:param angle: 三轴旋转角 弧度
	:param tran: 三轴平移量 米
	:param mode: 模式是{a}内变换 还是 {a}到{b}变换
	:return: p在{b}中的坐标 或 p在{a}中的新位置
	'''
	psi = angle[0]  # x
	phi = angle[1]  # y
	theta = angle[2]  # z
	# print('theta', theta)
	point = np.array([point[0], point[1], point[2], 1])

	Rx = np.array([[1, 0, 0, 0],
	               [0, cos(psi), sin(psi), 0],
	               [0, -sin(psi), cos(psi), 0],
	               [0, 0, 0, 1]])
	# print(Rx)
	Ry = np.array([[cos(phi), 0, -sin(phi), 0],
	               [0, 1, 0, 0],
	               [sin(phi), 0, cos(phi), 0],
	               [0, 0, 0, 1]])
	Rz = np.array([[cos(theta), sin(theta), 0, 0],
	               [-sin(theta), cos(theta), 0, 0],
	               [0, 0, 1, 0],
	               [0, 0, 0, 1]])
	t = np.array([[1, 0, 0, tran[0]],
	              [0, 1, 0, tran[1]],
	              [0, 0, 1, tran[2]],
	              [0, 0, 0, 1]])
	# print('Rz', Rz)
	Rxt = np.array([[1, 0, 0, 0],
	               [0, cos(psi), -sin(psi), 0],
	               [0, sin(psi), cos(psi), 0],
	               [0, 0, 0, 1]])
	# print(Rxt)
	Ryt = np.array([[cos(phi), 0, sin(phi), 0],
	               [0, 1, 0, 0],
	               [-sin(phi), 0, cos(phi), 0],
	               [0, 0, 0, 1]])
	Rzt = np.array([[cos(theta), -sin(theta), 0, 0],
	               [sin(theta), cos(theta), 0, 0],
	               [0, 0, 1, 0],
	               [0, 0, 0, 1]])
	if mode:
		T = np.dot(np.dot(np.dot(Rx, Ry), Rz), t)
	else:
		T = np.dot(np.dot(np.dot(Rxt, Ryt), Rzt), t)


	return np.dot(T, point)[0:3]


# 测试上面的函数
p_new = trans(point=np.array([1, 0, 0]), angle=np.array([0, 0, pi/2]), tran=np.array([0, 0, 0]), mode=False)
print('p_new', p_new)


# 角度限制, 输入弧度, 输出-pi~pi
def rad_limit(rad: float) -> float:
	rad = rad - 2 * pi * floor(rad / (2 * pi))
	if rad >= pi:
		rad = rad - 2 * pi
	if rad < -pi:
		rad = rad + 2 * pi
	return rad


# 求一般式直线交点
def get_cross_point(a=np.zeros(3), b=np.zeros(3)):
	point = np.zeros(3)
	flag = False
	try:
		if a[0] * b[1] == a[1] * b[0]:
			# 两直线平行 没有交点
			raise OSError('平行')
		else:
			x = (b[2] * a[1] - a[2] * b[1]) / (a[0] * b[1] - b[0] * a[1])
			y = (a[2] * b[0] - b[2] * a[0]) / (a[0] * b[1] - b[0] * a[1])
			point[0] = x
			point[1] = y
	except OSError:
		return point, flag
	flag = True
	return point, flag


def get_h_line(point, n):
	'''
	# 已知经过点point 平行直线的矢量 求该直线的一般方程
	:param point: 直线经过点1*3
	:param n: 直线法向量1*3
	:return: 直线Ax + By + C = 0的[A B C]
	'''
	eta = 1e-5
	coe = [0, 0, 0]
	if abs(voc_dot(n, np.array([0, 1, 0]))) < eta:
		# 如果u矢量垂直于vx轴
		print('与纵轴平行 本直线B=0')
		coe[0] = 1
		coe[1] = 0
		coe[2] = point[0]
	else:
		print('与纵轴不平行 B不为0')
		coe[0] = n[1] / n[0]
		coe[1] = -1
		coe[2] = point[1] - coe[0] * point[0]
	print('coe', coe)
	return coe


def vector_dot_angle(A, B) -> float:
	'''
	矢量夹角 从A向B的夹角 带正负号 右手顺时针为正 逆时针为负
	A、B都是三维矢量{x, y, z}，AB在同一个平面上

	:param A:
	:param B:
	:return: 夹角 -pi~pi
	'''
	val = (A[0] * B[0] + A[1] * B[1]) / (sqrt((A[0] ** 2) + (A[1] ** 2)) * sqrt((B[0] ** 2) + (B[1] ** 2)))
	# 小数位数太多 导致val大于1 acos无解
	angle = acos(round(val, 2))
	C = vec_cross(A, B)
	flag = 0
	if C[2] > 0:
		flag = 1
	else:
		flag = -1
	return flag * angle


def voc_dot(a=np.zeros(3), b=np.zeros(3)):
	c = a[0] * b[0] + a[1] * b[1] + a[2] * b[2]
	return c


def vec_cross(a=np.zeros(3, dtype=float), b=np.zeros(3, dtype=float)) -> list:
	'''
	矢量叉乘 从a向b
	:param a:
	:param b:
	:return:
	'''
	c = np.zeros(3)
	c[0] = a[1] * b[2] - a[2] * b[1]
	c[1] = -a[0] * b[2] + b[0] * a[2]
	c[2] = a[0] * b[1] - a[1] * b[0]
	return c


# --------------------------SCIPY INTERPOLATE START--------------------------
def my_interpolate(x_observed, y_observed, method='pchip', p_num=50):
	'''

	:param x_observed:
	:param y_observed:
	:param method: pchip--三次埃特金插值 cubic--三次样条 liner--线性插值
	:return:
	'''
	x, y = [], []
	if method == 'pchip':
		x = np.linspace(min(x_observed), max(x_observed), num=p_num)
		y = pchip_interpolate(x_observed, y_observed, x)
	elif method == 'cubic':
		x = np.linspace(min(x_observed), max(x_observed), num=p_num)
		cs = CubicSpline(x_observed, y_observed)
		y = cs(x)
	return x, y


# 测试例子 Handbook of marinecraft... chaper10 section10.4.1 figure10.14
# 三次埃特金插值的波动更小 虽然没有样条插值光顺
# x_observed = [0, 100, 500, 700, 1000]
# y_observed = [0, 100, 100, 200, 160]
# px, py = my_interpolate(x_observed, y_observed, 'pchip')
# cx, cy = my_interpolate(x_observed, y_observed, 'cubic')
# plt.plot(y_observed, x_observed, "o", label="observation")
# plt.plot(py, px, label="pchip interpolation")
# plt.plot(cy, cx, label="cubicSpline interpolation")
# plt.xlabel('y east')
# plt.ylabel('x north')
# plt.legend()
# plt.show()
# --------------------------SCIPY INTERPOLATE END--------------------------


def norm(a):
	'''
    计算矢量模
    :param a: ndarray数组
    :return:
    '''
	result = 0
	for i in a:
		result += i ** 2
	return np.sqrt(result)

def get_circle(point, r):
	ta = np.linspace(0, 2 * pi, 360)
	x_c = point[0] + r * cos(ta)
	y_c = point[1] + r * sin(ta)
	return x_c, y_c

def in_or_out_poly(point, poly):
	'''
	判断点是否在多边形的内部 grapscan
	:param point:
	:param poly:
	:return:
	'''
	flag = False
	for i in range(len(poly)):
		a_p = poly[i]
		b_p = poly[i+1]
		vec_a = a_p - point
		vec_b = b_p - point
		vec_cross()


def in_half_plane(hp, point):
	'''
	判断点是否在半平面内
	:param hp: 半平面 ndarray 1*6
	:param point: 1*3 点
	:return: true在 or false不在
	'''
	# 如果直线上任一点到point的矢量与矢量n夹角绝对值小于 pi/2
	# 则点在半平面内
	flag = False
	A = hp[0]
	B = hp[1]
	C = hp[2]
	vec_n = [hp[3], hp[4], hp[5]]
	if B == 0:
		y = 0
		x = -C / A

	else:
		# 取x = 0
		x = 0
		y = -C / B
	vec = point - np.array([x, y, 0])
	angle = vector_dot_angle(vec, vec_n)
	if abs(angle) <= 0.5 * pi:
		flag = True
	else:
		flag = False

	return flag


def point_to_line(a, b):
	'''

	:param a: 直线方程
	:param b: 点
	:return: 距离
	'''
	# 点到直线距离公式
	dis = abs(a[0] * b[0] + a[1] * b[1] + a[2]) / sqrt(a[0] ** 2 + a[1] ** 2)
	return dis